from .baselib_dataset import BaseLibImageDataset
from .test_dataset import VesselTestDataset
from .videoreid_dataset import VesselImageDataset

# Registry mapping each dataset mode name to the dataset class built for it.
# Insertion order ('train', 'test', 'baselib') is what get_names() exposes.
__datasetFactory = dict(
    train=VesselImageDataset,
    test=VesselTestDataset,
    baselib=BaseLibImageDataset,
)


def get_names():
    """Return the names of all registered dataset modes.

    The result is the registry's live ``dict_keys`` view, not a copy.
    """
    registry = __datasetFactory
    return registry.keys()


def init_dataset(mode, cfg):
    """Construct the dataset registered under *mode*, rooted at its configured path.

    Args:
        mode: Registered dataset name: 'train', 'test' or 'baselib'.
        cfg: Config object. Only the single ``cfg.params`` attribute that
            matches *mode* is read: ``TRAIN_FILE_PATH``, ``TEST_FILE_PATH``
            or ``BASELIB_FILE_PATH``.

    Returns:
        An instance of the dataset class registered for *mode*, constructed
        as ``DatasetClass(root=<path>)``.

    Raises:
        KeyError: If *mode* is not a registered dataset name.
        AttributeError: If ``cfg.params`` lacks the path attribute for *mode*.
    """
    # Map modes to attribute *names* rather than values. Previously every
    # path was read eagerly, so a config missing e.g. BASELIB_FILE_PATH
    # could not build even the 'train' dataset.
    path_attr_factory = {
        'train': 'TRAIN_FILE_PATH',
        'test': 'TEST_FILE_PATH',
        'baselib': 'BASELIB_FILE_PATH',
    }
    if mode not in __datasetFactory:
        raise KeyError("Unknown dataset: {}".format(mode))

    root = getattr(cfg.params, path_attr_factory[mode])
    return __datasetFactory[mode](root=root)
